{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K7t_qlnTVWWS"
      },
      "outputs": [],
      "source": [
        "# Install the dependencies first if they are missing (uncomment to run).\n",
        "# The notebook imports both snntorch and tonic.\n",
        "# %pip install snntorch tonic"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YpJrTa3XVaSQ"
      },
      "outputs": [],
      "source": [
        "import snntorch as snn\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Root directory for the dataset files; handed to the tonic datasets as `save_to`.\n",
        "DATADIR = \"/tmp/data\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kbHJ827iVcYY"
      },
      "source": [
        "**Note on the legacy loader.** This notebook originally loaded DVS Gesture through `snntorch.spikevision`. With the PyTorch version it was last run on, that call fails with `AttributeError: module 'torch' has no attribute '_six'`, and the resulting `train_ds` / `test_ds` were overwritten by the `tonic` datasets anyway. The dataset is therefore loaded with `tonic` only (next cell).\n",
        "\n",
        "For reference, the old call was:\n",
        "\n",
        "```python\n",
        "from snntorch.spikevision import spikedata\n",
        "# note that a default transform is already applied to keep things easy\n",
        "# ds: spatial compression; dt: temporal compression\n",
        "train_ds = spikedata.DVSGesture(\"/tmp/data/dvsgesture/\", train=True, dt=1000, num_steps=500, ds=1)\n",
        "test_ds = spikedata.DVSGesture(\"/tmp/data/dvsgesture/\", train=False, dt=1000, num_steps=1800, ds=1)\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tonic\n",
        "import tonic.transforms as transforms\n",
        "\n",
        "# Sensor dimensions as declared on the dataset class; passed on to ToFrame below\n",
        "sensor_size = tonic.datasets.DVSGesture.sensor_size\n",
        "\n",
        "# Event pipeline: Denoise first removes isolated, one-off events, then ToFrame\n",
        "# turns the remaining events into frames using the given `time_window`\n",
        "frame_transform = transforms.Compose(\n",
        "    [\n",
        "        transforms.Denoise(filter_time=10000),\n",
        "        transforms.ToFrame(sensor_size=sensor_size, time_window=1000),\n",
        "    ]\n",
        ")\n",
        "\n",
        "# Build the train split first, then the test split, both with the same transform\n",
        "train_ds, test_ds = (\n",
        "    tonic.datasets.DVSGesture(save_to=DATADIR, transform=frame_transform, train=is_train)\n",
        "    for is_train in (True, False)\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 163
        },
        "id": "AWQTG9lhVmKY",
        "outputId": "3d607d3e-b416-4a47-e61e-4d562ed52dba"
      },
      "outputs": [],
      "source": [
        "# Bare expression: show the test dataset's repr as the cell output\n",
        "test_ds"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "buI7BN_JVvSb"
      },
      "source": [
        "## DataLoaders"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from torch.utils.data import DataLoader\n",
        "from tonic.collation import PadTensors\n",
        "\n",
        "BATCH_SIZE = 64\n",
        "\n",
        "# ToFrame(time_window=...) produces one frame per time window, so the number of\n",
        "# frames in a sample depends on how long the recording is. The default collate\n",
        "# function cannot stack samples of different lengths; PadTensors pads every\n",
        "# sample in a batch to the longest one (batch_first=True -> batch is axis 0).\n",
        "pad_collate = PadTensors(batch_first=True)\n",
        "\n",
        "train_dl = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE, collate_fn=pad_collate)\n",
        "test_dl = DataLoader(test_ds, shuffle=False, batch_size=BATCH_SIZE, collate_fn=pad_collate)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sanity check: number of samples in the dataset wrapped by the training loader\n",
        "print('the number of items in the dataset is', len(train_dl.dataset))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Play with Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "# Pull a single (frames, label) pair out of the training set to get a feel for the data\n",
        "i_item = 42 # arbitrary fixed index of the sample to inspect\n",
        "data, label = train_dl.dataset[i_item]\n",
        "\n",
        "# Wrap the sample as a torch tensor; the plotting cells below slice and sum it\n",
        "data = torch.Tensor(data)\n",
        "\n",
        "print('The data sample has size', data.shape)\n",
        "print(f\"in case you're blind AF, the target is: {label} ({train_ds.classes[label]})\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Bare expression: list the class names; `label` above indexes into this list\n",
        "train_ds.classes"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Visualize"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import snntorch.spikeplot as splt\n",
        "from IPython.display import HTML, display\n",
        "import numpy as np\n",
        "\n",
        "# Merge the two polarity channels of the first 300 frames into one signed map\n",
        "# per frame (channel 0 minus channel 1)\n",
        "a = data[:300, 0] - data[:300, 1]\n",
        "\n",
        "# Animate the frames and embed the result in the notebook as an HTML5 video\n",
        "fig, ax = plt.subplots()\n",
        "anim = splt.animator(a, fig, ax, interval=200)\n",
        "HTML(anim.to_html5_video())\n",
        "# Optional, to write the animation to disk instead (file name is a leftover from an N-MNIST example):\n",
        "# anim.save('nmnist_animation.mp4', writer = 'ffmpeg', fps=50)  "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iDDqdJ1YWBns"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import torch\n",
        "\n",
        "# Heat map: channel 0 of the first 300 frames, summed over the frame axis\n",
        "plt.imshow(data[:300, 0].sum(dim=0), cmap='hot')\n",
        "plt.colorbar()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "DVS_Gesture.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
